{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Module 5](https://dataflowr.github.io/website/modules/5-stacking-layers/): overfitting an MLP on CIFAR10\n",
    "\n",
    "Training loop over CIFAR10 (50,000 train images, 10,000 test images). What happens if you\n",
    "- switch the training to a GPU? Is it faster?\n",
    "- remove the `ReLU()`?\n",
    "- increase the learning rate?\n",
    "- stack more layers?\n",
    "- perform more epochs?\n",
    "\n",
    "Can you completely overfit the training set (i.e. get 100% accuracy)?\n",
    "\n",
    "This code is not modular at all. Create functions for each specific task.\n",
    "(hint: see [this](https://github.com/pytorch/examples/blob/master/mnist/main.py))\n",
    "\n",
    "Your training went well. Good. Why not save the weights of the network (`net.state_dict()`) using `torch.save()`?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as t\n",
    "\n",
    "# define network structure \n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "\n",
    "# load data\n",
    "to_tensor =  t.ToTensor()\n",
    "normalize = t.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "flatten =  t.Lambda(lambda x:x.view(-1))\n",
    "\n",
    "transform_list = t.Compose([to_tensor, normalize, flatten])\n",
    "train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_list, download=True)\n",
    "test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_list, download=True)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)\n",
    "\n",
    "# === Train === ###\n",
    "net.train()\n",
    "\n",
    "# train loop\n",
    "for epoch in range(3):\n",
    "    train_correct = 0\n",
    "    train_loss = 0\n",
    "    print('Epoch {}'.format(epoch))\n",
    "    \n",
    "    # loop per epoch \n",
    "    for i, (batch, targets) in enumerate(train_loader):\n",
    "\n",
    "        output = net(batch)\n",
    "        loss = criterion(output, targets)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        train_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
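    "        train_loss += loss.item()  # .item() gives a Python float, so the graph is not kept alive\n",
    "\n",
    "        # criterion returns the batch mean, so average over batches, not over samples\n",
    "        if i % 100 == 10: print('Train loss {:.4f}, Train accuracy {:.2f}%'.format(\n",
    "            train_loss / (i+1), 100 * train_correct / ((i+1) * 64)))\n",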
    "        \n",
    "print('End of training.\\n')\n",
    "    \n",
    "# === Test === ###\n",
    "test_correct = 0\n",
    "net.eval()\n",
    "\n",
    "# loop over the whole test set; no gradients are needed for evaluation\n",
    "with torch.no_grad():\n",
    "    for i, (batch, targets) in enumerate(test_loader):\n",
    "\n",
    "        output = net(batch)\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        test_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "\n",
    "# the last batch is smaller than 64, so divide by the true number of test images\n",
    "print('End of testing. Test accuracy {:.2f}%'.format(\n",
    "    100 * test_correct / len(test_set)))"
   ]
  },
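  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last question above can be sketched as follows. This is only a minimal illustration, not the reference solution: the names `model`, `restored` and the file name `mlp_cifar10.pt` are assumptions made for this example, and `model` is an untrained stand-in with the same architecture as `net`. The idea is to save the `state_dict()` with `torch.save()`, rebuild the architecture, and load the parameters back with `load_state_dict()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Stand-in for the trained network (same architecture, untrained weights).\n",
    "model = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "\n",
    "# Hypothetical file name, written to the temporary directory for this sketch.\n",
    "path = os.path.join(tempfile.gettempdir(), 'mlp_cifar10.pt')\n",
    "torch.save(model.state_dict(), path)\n",
    "\n",
    "# Rebuild the same architecture, then load the saved parameters into it.\n",
    "restored = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "restored.load_state_dict(torch.load(path))\n",
    "\n",
    "# Check that every tensor in the two state dicts is identical.\n",
    "same = all(torch.equal(v, restored.state_dict()[k]) for k, v in model.state_dict().items())\n",
    "print(same)"
   ]
  },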
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd tips and tricks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pointers are everywhere!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight\n",
    "print(w)\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "net.weight.data -= 0.01 * net.weight.grad # <--- What is this?\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight.clone()\n",
    "print(w)\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "net.weight.data -= 0.01 * net.weight.grad # <--- What is this?\n",
    "print(w)"
   ]
  },
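  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see what the two cells above are doing is to compare memory addresses with `data_ptr()`. The sketch below is an illustration under assumed names (`layer`, `alias`, `snapshot`): plain assignment gives a second name for the same `Parameter`, while `clone()` allocates new storage. It updates the weight in place inside `torch.no_grad()`, which is a common alternative to going through `.data`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "layer = nn.Linear(2, 2)\n",
    "alias = layer.weight             # same Parameter object, no copy\n",
    "snapshot = layer.weight.clone()  # new tensor with its own storage\n",
    "\n",
    "print(alias is layer.weight)                           # same Python object?\n",
    "print(alias.data_ptr() == layer.weight.data_ptr())     # same memory?\n",
    "print(snapshot.data_ptr() == layer.weight.data_ptr())  # same memory?\n",
    "\n",
    "# In-place update of the parameter, done without tracking gradients.\n",
    "with torch.no_grad():\n",
    "    layer.weight -= 1.0\n",
    "\n",
    "print(torch.equal(alias, layer.weight))     # does the alias follow the update?\n",
    "print(torch.equal(snapshot, layer.weight))  # does the clone follow the update?"
   ]
  },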
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sharing weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
    "net[0].weight = net[1].weight  # weight sharing\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "print(net[0].weight.grad)\n",
    "print(net[1].weight.grad)"
   ]
  },
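  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To interpret the two printed gradients, it may help to compare against a network whose layers are not tied. The sketch below is a hedged illustration (the names `shared`, `untied` and `total` are assumptions for this example): it copies the shared values into an untied network and checks whether the gradient of the shared weight matches the sum of the two per-layer gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x = torch.rand(1, 2)\n",
    "\n",
    "# Shared case: both layers hold the very same Parameter object.\n",
    "shared = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
    "shared[0].weight = shared[1].weight\n",
    "shared(x).sum().backward()\n",
    "print(shared[0].weight is shared[1].weight)\n",
    "\n",
    "# Untied network with identical values, so each layer gets its own gradient.\n",
    "untied = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
    "untied.load_state_dict(shared.state_dict())\n",
    "untied(x).sum().backward()\n",
    "\n",
    "# Compare the shared gradient with the sum of the two separate gradients.\n",
    "total = untied[0].weight.grad + untied[1].weight.grad\n",
    "print(torch.allclose(shared[0].weight.grad, total))"
   ]
  },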
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
